//
// Copyright Contributors to the MaterialX Project
// SPDX-License-Identifier: Apache-2.0
//

#ifndef MATERIALX_SHADERGRAPH_H
#define MATERIALX_SHADERGRAPH_H

/// @file
/// Shader graph class

#include <MaterialXGenShader/Export.h>

#include <MaterialXGenShader/ColorManagementSystem.h>
#include <MaterialXGenShader/ShaderNode.h>
#include <MaterialXGenShader/Syntax.h>
#include <MaterialXGenShader/TypeDesc.h>
#include <MaterialXGenShader/UnitSystem.h>

#include <MaterialXCore/Document.h>
#include <MaterialXCore/Node.h>

MATERIALX_NAMESPACE_BEGIN

// Forward declarations.
class Syntax;
class ShaderGraphEdge;
class ShaderGraphEdgeIterator;
class GenOptions;

/// An internal input socket in a shader graph,
/// used for connecting internal nodes to the outside.
/// This is an alias of ShaderOutput: ShaderGraph serves its input sockets
/// through the output-port accessors (e.g. getInputSocket() forwards to
/// getOutput()).
using ShaderGraphInputSocket = ShaderOutput;

/// An internal output socket in a shader graph,
/// used for connecting internal nodes to the outside.
/// This is an alias of ShaderInput: ShaderGraph serves its output sockets
/// through the input-port accessors (e.g. getOutputSocket() forwards to
/// getInput()).
using ShaderGraphOutputSocket = ShaderInput;

/// A shared pointer to a shader graph
using ShaderGraphPtr = shared_ptr<class ShaderGraph>;

/// @class ShaderGraph
/// Class representing a graph (DAG) for shader generation.
///
/// A graph is itself a ShaderNode. Its socket accessors forward to the port
/// accessors with inverted naming: input sockets are ShaderOutput ports and
/// output sockets are ShaderInput ports (see the ShaderGraphInputSocket and
/// ShaderGraphOutputSocket aliases).
class MX_GENSHADER_API ShaderGraph : public ShaderNode
{
  public:
    /// Constructor.
    /// @param parent Parent graph of this graph.
    /// @param name Name of the graph.
    /// @param document Document associated with the graph.
    ShaderGraph(const ShaderGraph* parent, const string& name, ConstDocumentPtr document);

    /// Destructor.
    virtual ~ShaderGraph() { }

    /// Create a new shader graph from an element.
    /// Supported elements are outputs and shader nodes.
    /// @param parent Parent graph of the new graph.
    /// @param name Name of the new graph.
    /// @param element Element to create the graph from.
    /// @param context Context for generation.
    static ShaderGraphPtr create(const ShaderGraph* parent, const string& name, ElementPtr element,
                                 GenContext& context);

    /// Create a new shader graph from a nodegraph.
    /// @param parent Parent graph of the new graph.
    /// @param nodeGraph Nodegraph to create the graph from.
    /// @param context Context for generation.
    static ShaderGraphPtr create(const ShaderGraph* parent, const NodeGraph& nodeGraph,
                                 GenContext& context);

    /// Return true if this node is a graph.
    bool isAGraph() const override { return true; }

    /// Get an internal node by name
    ShaderNode* getNode(const string& name);

    /// Get an internal node by name
    const ShaderNode* getNode(const string& name) const;

    /// Get a vector of all nodes in order
    const vector<ShaderNode*>& getNodes() const { return _nodeOrder; }

    /// Get number of input sockets.
    /// Input sockets are ShaderOutput ports, hence the forward to numOutputs().
    size_t numInputSockets() const { return numOutputs(); }

    /// Get number of output sockets.
    /// Output sockets are ShaderInput ports, hence the forward to numInputs().
    size_t numOutputSockets() const { return numInputs(); }

    /// Get socket by index.
    /// The output socket overloads default to index 0.
    ShaderGraphInputSocket* getInputSocket(size_t index) { return getOutput(index); }
    ShaderGraphOutputSocket* getOutputSocket(size_t index = 0) { return getInput(index); }
    const ShaderGraphInputSocket* getInputSocket(size_t index) const { return getOutput(index); }
    const ShaderGraphOutputSocket* getOutputSocket(size_t index = 0) const { return getInput(index); }

    /// Get socket by name
    ShaderGraphInputSocket* getInputSocket(const string& name) { return getOutput(name); }
    ShaderGraphOutputSocket* getOutputSocket(const string& name) { return getInput(name); }
    const ShaderGraphInputSocket* getInputSocket(const string& name) const { return getOutput(name); }
    const ShaderGraphOutputSocket* getOutputSocket(const string& name) const { return getInput(name); }

    /// Get vector of sockets
    const vector<ShaderGraphInputSocket*>& getInputSockets() const { return _outputOrder; }
    const vector<ShaderGraphOutputSocket*>& getOutputSockets() const { return _inputOrder; }

    /// Apply color and unit transforms to each input of a node.
    void applyInputTransforms(ConstNodePtr node, ShaderNode* shaderNode, GenContext& context);

    /// Create a new node in the graph
    ShaderNode* createNode(ConstNodePtr node, GenContext& context);

    /// Inline a new node before the given output socket.
    /// NOTE(review): the parameter descriptions below are inferred from the
    /// signature only; confirm against the implementation.
    /// @param output Output socket that the new node is placed before.
    /// @param newNodeName Name given to the new node.
    /// @param nodeDefName Name of the node definition used for the new node.
    /// @param inputName Name of the port on the new node used as its input.
    /// @param outputName Name of the port on the new node used as its output.
    /// @param context Context for generation.
    /// @return Presumably the newly created node.
    ShaderNode* inlineNodeBeforeOutput(ShaderGraphOutputSocket* output,
                                        const std::string& newNodeName,
                                        const std::string& nodeDefName,
                                        const std::string& inputName,
                                        const std::string& outputName,
                                        GenContext& context);

    /// Add input sockets
    ShaderGraphInputSocket* addInputSocket(const string& name, TypeDesc type);
    /// Deprecated pointer-based overload; forwards to the by-value overload.
    /// The type pointer is dereferenced unconditionally and must not be null.
    [[deprecated("Use addInputSocket(const string&, TypeDesc) instead")]] ShaderGraphInputSocket* addInputSocket(const string& name, const TypeDesc* type) { return addInputSocket(name, *type); }

    /// Add output sockets
    ShaderGraphOutputSocket* addOutputSocket(const string& name, TypeDesc type);
    /// Deprecated pointer-based overload; forwards to the by-value overload.
    /// The type pointer is dereferenced unconditionally and must not be null.
    [[deprecated("Use addOutputSocket(const string&, TypeDesc) instead")]] ShaderGraphOutputSocket* addOutputSocket(const string& name, const TypeDesc* type) { return addOutputSocket(name, *type); }

    /// Add a default geometric node and connect to the given input.
    void addDefaultGeomNode(ShaderInput* input, const GeomPropDef& geomprop, GenContext& context);

    /// Sort the nodes in topological order.
    void topologicalSort();

    /// Return an iterator for traversal upstream from the given output
    static ShaderGraphEdgeIterator traverseUpstream(ShaderOutput* output);

    /// Return the map of unique identifiers used in the scope of this graph.
    IdentifierMap& getIdentifierMap() { return _identifiers; }

  protected:
    /// Create node connections corresponding to the connection between a pair of elements.
    /// @param downstreamElement Element representing the node to connect to.
    /// @param upstreamElement Element representing the node to connect from.
    /// @param connectingElement If non-null, specifies the element on the downstream node to connect to.
    /// @param context Context for generation.
    void createConnectedNodes(const ElementPtr& downstreamElement,
                              const ElementPtr& upstreamElement,
                              ElementPtr connectingElement,
                              GenContext& context);

    /// Create a new node in a graph from a node definition.
    /// Note - this does not initialize the node instance with any concrete values, but
    /// instead creates an empty instance of the provided node definition
    ShaderNode* createNode(const string& name, ConstNodeDefPtr nodeDef, GenContext& context);

    /// Add a node to the graph
    void addNode(ShaderNodePtr node);

    /// Add input sockets from an interface element (nodedef, nodegraph or node)
    void addInputSockets(const InterfaceElement& elem, GenContext& context);

    /// Add output sockets from an interface element (nodedef, nodegraph or node)
    void addOutputSockets(const InterfaceElement& elem, GenContext& context);

    /// Traverse from the given root element and add all dependencies upstream.
    /// The traversal is done in the context of a material, if given, to include
    /// bind input elements in the traversal.
    /// NOTE(review): the signature takes no material argument, so the sentence
    /// above may be stale; confirm against the implementation.
    void addUpstreamDependencies(const Element& root, GenContext& context);

    /// Add a color transform node and connect to the given input.
    void addColorTransformNode(ShaderInput* input, const ColorSpaceTransform& transform, GenContext& context);

    /// Add a color transform node and connect to the given output.
    void addColorTransformNode(ShaderOutput* output, const ColorSpaceTransform& transform, GenContext& context);

    /// Add a unit transform node and connect to the given input.
    void addUnitTransformNode(ShaderInput* input, const UnitTransform& transform, GenContext& context);

    /// Add a unit transform node and connect to the given output.
    void addUnitTransformNode(ShaderOutput* output, const UnitTransform& transform, GenContext& context);

    /// Perform all post-build operations on the graph.
    void finalize(GenContext& context);

    /// Optimize the graph, removing redundant paths.
    void optimize(GenContext& context);

    /// Bypass a node for a particular input and output,
    /// effectively connecting the input's upstream connection
    /// with the output's downstream connections.
    /// @param node Node to bypass.
    /// @param inputIndex Index of the input on the node.
    /// @param outputIndex Index of the output on the node, defaulting to the first output.
    void bypass(ShaderNode* node, size_t inputIndex, size_t outputIndex = 0);

    /// For inputs and outputs in the graph set the variable names to be used
    /// in generated code. Making sure variable names are valid and unique
    /// to avoid name conflicts during shader generation.
    void setVariableNames(GenContext& context);

    /// Populate the color transform map for the given shader port, if the provided combination of
    /// source and target color spaces are supported for its data type.
    void populateColorTransformMap(ColorManagementSystemPtr colorManagementSystem, ShaderPort* shaderPort,
                                   const string& sourceColorSpace, const string& targetColorSpace, bool asInput);

    /// Populates the appropriate unit transform map if the provided input/parameter or output
    /// has a unit attribute and is of the supported type
    void populateUnitTransformMap(UnitSystemPtr unitSystem, ShaderPort* shaderPort, ValueElementPtr element, const string& targetUnitSpace, bool asInput);

    /// Break all connections on a node
    void disconnect(ShaderNode* node) const;

    // Document passed in at construction (presumably; confirm in the constructor).
    ConstDocumentPtr _document;
    // Nodes keyed by name, held through shared pointers.
    std::unordered_map<string, ShaderNodePtr> _nodeMap;
    // Raw pointers to the nodes in order, as returned by getNodes().
    std::vector<ShaderNode*> _nodeOrder;
    // Unique identifiers used in the scope of this graph.
    IdentifierMap _identifiers;

    // Temporary storage for inputs that require color transformations
    std::vector<std::pair<ShaderInput*, ColorSpaceTransform>> _inputColorTransformMap;
    // Temporary storage for inputs that require unit transformations
    std::vector<std::pair<ShaderInput*, UnitTransform>> _inputUnitTransformMap;

    // Temporary storage for outputs that require color transformations
    std::vector<std::pair<ShaderOutput*, ColorSpaceTransform>> _outputColorTransformMap;
    // Temporary storage for outputs that require unit transformations
    std::vector<std::pair<ShaderOutput*, UnitTransform>> _outputUnitTransformMap;
};

/// @class ShaderGraphEdge
/// An edge returned during shader graph traversal, pairing an upstream
/// output port with a downstream input port.
///
/// Edges are compared by pointer identity of their two ports. The ordering
/// operator lets edges be stored in ordered containers such as std::set.
class MX_GENSHADER_API ShaderGraphEdge
{
  public:
    /// Construct an edge from its two end points.
    /// @param up The upstream output port.
    /// @param down The downstream input port.
    ShaderGraphEdge(ShaderOutput* up, ShaderInput* down) :
        upstream(up),
        downstream(down)
    {
    }

    /// Return true if both edges refer to the same upstream and downstream ports.
    bool operator==(const ShaderGraphEdge& rhs) const
    {
        return upstream == rhs.upstream && downstream == rhs.downstream;
    }

    /// Return true if the edges differ in either port.
    bool operator!=(const ShaderGraphEdge& rhs) const
    {
        return !(*this == rhs);
    }

    /// Lexicographic ordering on (upstream, downstream).
    /// The pointers are compared with std::less rather than the built-in
    /// operator<, since only std::less guarantees a strict total order for
    /// pointers into unrelated objects.
    bool operator<(const ShaderGraphEdge& rhs) const
    {
        if (upstream != rhs.upstream)
        {
            return std::less<ShaderOutput*>()(upstream, rhs.upstream);
        }
        return std::less<ShaderInput*>()(downstream, rhs.downstream);
    }

    /// The upstream output port of the edge.
    ShaderOutput* upstream;
    /// The downstream input port of the edge.
    ShaderInput* downstream;
};

/// @class ShaderGraphEdgeIterator
/// Iterator class for traversing edges between nodes in a shader graph.
///
/// The iterator object doubles as its own range: begin() returns the object
/// itself and end() returns a shared end iterator.
class MX_GENSHADER_API ShaderGraphEdgeIterator
{
  public:
    /// Construct an iterator for the given output.
    ShaderGraphEdgeIterator(ShaderOutput* output);
    ~ShaderGraphEdgeIterator() = default;

    /// Return true if both iterators have the same current edge and the same
    /// traversal stack. The path and visited-edge sets are not compared.
    bool operator==(const ShaderGraphEdgeIterator& rhs) const
    {
        if (_upstream != rhs._upstream || _downstream != rhs._downstream)
        {
            return false;
        }
        return _stack == rhs._stack;
    }

    /// Return true if the iterators do not compare equal.
    bool operator!=(const ShaderGraphEdgeIterator& rhs) const
    {
        return !operator==(rhs);
    }

    /// Dereference this iterator, returning the current edge in the traversal.
    ShaderGraphEdge operator*() const
    {
        return { _upstream, _downstream };
    }

    /// Iterate to the next edge in the traversal.
    /// @throws ExceptionFoundCycle if a cycle is encountered.
    ShaderGraphEdgeIterator& operator++();

    /// Return a reference to this iterator to begin traversal
    ShaderGraphEdgeIterator& begin()
    {
        return *this;
    }

    /// Return the end iterator.
    static const ShaderGraphEdgeIterator& end();

  private:
    // Traversal helpers; their behavior is defined in the implementation file.
    void extendPathUpstream(ShaderOutput* upstream, ShaderInput* downstream);
    void returnPathDownstream(ShaderOutput* upstream);
    bool skipOrMarkAsVisited(ShaderGraphEdge edge);

    // End points of the current edge, as returned by operator*.
    ShaderOutput* _upstream;
    ShaderInput* _downstream;

    // Traversal stack; each frame pairs an output with an index.
    using StackFrame = std::pair<ShaderOutput*, size_t>;
    std::vector<StackFrame> _stack;

    // Outputs on the current path (presumably used for cycle detection; confirm).
    std::set<ShaderOutput*> _path;
    // Edges recorded as visited by the traversal.
    std::set<ShaderGraphEdge> _visitedEdges;
};

MATERIALX_NAMESPACE_END

#endif
